import numpy as np


def REF(tp1, n):
    """Reference past values of ``tp1``, looking back ``n`` bars.

    ``n`` may be a scalar (the same look-back for every bar, handled by
    ``REF_NUM``) or a per-bar sequence of look-backs such as a list, numpy
    array or pandas Series (handled by ``REF_NP``).
    """
    # np.ndim is 0 for scalars and >= 1 for any array-like, so lists and
    # pandas Series of periods reach the per-bar path instead of the scalar
    # path (where ``while i < n`` would fail on a sequence).
    if np.ndim(n) > 0:
        return REF_NP(tp1, n)
    return REF_NUM(tp1, n)


def REF_NP(tp1, n):
    """Look back a per-bar number of bars.

    Element ``i`` of the result is ``tp1[i - n[i]]`` (positional). When the
    look-back reaches before the first bar, it is clamped to ``tp1[0]``.

    Args:
        tp1: Values to reference (sequence, numpy array or pandas Series).
        n: Per-bar look-back lengths, one per element of ``tp1``. Float values
            are truncated to int so periods computed by other indicators work.

    Returns:
        numpy array with the same length as ``tp1``.

    Raises:
        ValueError: if ``n`` has fewer elements than ``tp1``.
    """
    values = np.asarray(tp1)
    size = len(values)
    periods = np.asarray(n)
    if len(periods) < size:
        raise ValueError("n must provide a look-back for every element of tp1")
    # Cast to int: float period arrays cannot be used as indices.
    positions = np.arange(size) - periods[:size].astype(int)
    # Look-backs that reach before the first bar are clamped to bar 0.
    return values[np.maximum(positions, 0)]


def REF_NUM(tp1, n):
    """Look back a fixed number of bars ``n``.

    Element ``i`` of the result is ``tp1[i - n]`` for ``i >= n``. The first
    ``n`` elements keep their own values, since they have no bar ``n`` steps
    earlier. If ``n >= len(tp1)``, the input is returned unchanged.

    Args:
        tp1: Values to reference (any iterable).
        n: Non-negative look-back length (truncated to int).

    Returns:
        numpy array with the same length as ``tp1``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    periods = int(n)
    if periods < 0:
        raise ValueError("n must be non-negative")
    values = list(tp1)  # materialise once rather than once per element
    head = min(periods, len(values))
    # Unshifted leading values, followed by the series delayed by ``head`` bars.
    return np.array(values[:head] + values[:len(values) - head])


def LLV(s, n):
    """Return the lowest value of ``s`` over a rolling window of ``n`` bars.

    Thin wrapper around pandas rolling ``min``. With pandas' default
    ``min_periods``, the leading entries are NaN until the window is full.
    """
    window = s.rolling(n)
    lowest = window.min()
    return lowest


def EMA(s, n):
    """Exponential moving average of ``s`` with span ``n``.

    Uses the recursive form ``Y[t] = (2 * X[t] + (n - 1) * Y[t-1]) / (n + 1)``
    seeded with ``Y[0] = X[0]``, i.e. pandas ``ewm(span=n, adjust=False)``.

    Args:
        s: pandas Series of values.
        n: Span of the average.

    Returns:
        pandas Series of the same length as ``s``.
    """
    # pandas Rolling has no ``ema`` method; exponential weighting lives on ``ewm``.
    return s.ewm(span=n, adjust=False).mean()


def HHV(s, n):
    """Highest value of ``s`` over a look-back window of ``n`` bars.

    Args:
        s: pandas Series (scalar ``n``), or a Series/numpy array supporting
            slicing and ``.max()`` (sequence ``n``).
        n: Either an integer window length (Python or numpy integer), or a
            per-bar sequence of window lengths.

    Returns:
        For integer ``n``: ``s.rolling(n).max()``, a window that includes the
        current bar.
        For sequence ``n``: a list in which element ``i`` is the max of the
        last ``n[i]`` values strictly before bar ``i`` (``s[0:i]``). A period
        of 0 means the whole preceding history. Element 0 is 0.
        NOTE(review): this branch excludes the current bar while the integer
        branch includes it -- confirm that is intended.
    """
    # isinstance (not ``type(n) == int``) so numpy integers take the rolling path.
    if isinstance(n, (int, np.integer)):
        return s.rolling(n).max()
    # n is a sequence of per-bar window lengths.
    result = [0] * len(s)
    periods = list(n)
    for i in range(1, len(s)):
        history = s[0:i]  # bars strictly before bar i
        # -0 slices to the start, so a period of 0 uses the whole history.
        result[i] = history[-int(periods[i]):].max()
    return result


def CROSS(cond1, cond2):
    """Flag, element-wise, where ``cond1`` is above ``cond2``.

    Original note: "x1 crosses above x2".
    NOTE(review): the implementation only tests ``cond1 > cond2`` on each bar.
    It does not check that ``cond1`` was at or below ``cond2`` on the previous
    bar, so every bar where ``cond1`` is above is marked, not just the bar of
    the crossing. Confirm which behaviour callers expect.

    Args:
        cond1, cond2: Array-likes (lists, numpy arrays, pandas Series) or
            strings. Strings are evaluated as Python expressions in this
            module's namespace (legacy behaviour).

    Returns:
        list of 1/0 ints, one per element.
    """
    # SECURITY: eval executes arbitrary code -- never pass untrusted text here.
    left = eval(cond1) if isinstance(cond1, (str, bytes)) else cond1
    right = eval(cond2) if isinstance(cond2, (str, bytes)) else cond2
    # asarray makes plain lists compare element-wise instead of lexicographically.
    return np.where(np.asarray(left) > np.asarray(right), 1, 0).tolist()


def barslast(df):
    """Count the bars since the most recent element equal to 1.

    Args:
        df: Iterable of signal values (e.g. 0/1 ints or booleans).

    Returns:
        list of the same length as ``df``:
        - NaN for positions before the first 1.
        - 0 on every bar equal to 1.
        - Otherwise the number of bars since the last 1.
        If no element equals 1, every entry is NaN (previously ``None`` or an
        exception was produced in that case).
    """
    result = []
    since = None  # stays None until the first 1 has been seen
    for value in df:
        if value == 1:
            since = 0
        elif since is not None:
            since += 1
        result.append(np.nan if since is None else since)
    return result


# @nb.jit
def COUNT(cond, n):
    """Count entries of ``cond`` equal to True in sliding windows of ``n`` bars.

    The output has ``len(cond) - n`` elements. Element ``i`` counts the
    entries of ``cond[i + 1 : i + n + 1]``, i.e. the ``n`` bars ending at bar
    ``i + n``. When ``n == 0``, element ``i`` counts ``cond[0 : i + 1]``
    (cumulative from the first bar). If ``n >= len(cond)``, an empty array is
    returned.

    Args:
        cond: Sequence, numpy array or pandas Series of conditions.
        n: Non-negative window length.

    Returns:
        numpy int array of counts.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    # Exact equality with True (matches True/1/1.0), as the original mask did.
    hits = np.asarray(cond) == True  # noqa: E712
    size = len(hits) - n
    if size <= 0:
        return np.zeros(0, dtype=int)
    # csum[k] = number of hits in cond[0:k]; window counts are differences.
    csum = np.concatenate(([0], np.cumsum(hits, dtype=int)))
    window_end = csum[n + 1:]  # csum[i + n + 1] for i in range(size)
    window_start = 0 if n == 0 else csum[1:size + 1]  # csum[i + 1]
    return window_end - window_start


# @nb.jit
def COUNT_(cond, n):
    """Count entries of ``cond`` equal to True over per-bar look-back windows.

    Element ``i`` (for ``i >= 1``) is the number of True entries among the
    last ``n[i]`` values strictly before bar ``i`` (taken from ``cond[:i]``).
    A period of 0 means the whole preceding history. Element 0 is 0.
    NOTE(review): unlike ``COUNT``, the window excludes the current bar --
    confirm that is intended.

    Args:
        cond: Sequence, numpy array or pandas Series of conditions.
        n: Per-bar window lengths aligned with ``cond``. A scalar is used as
            the window length for every bar. Floats are truncated to int.

    Returns:
        numpy int array with the same length as ``cond``.
    """
    # Exact equality with True (matches True/1/1.0).
    hits = np.asarray(cond) == True  # noqa: E712
    periods = [n] * len(hits) if np.ndim(n) == 0 else list(n)
    result = [0] * len(hits)
    for i in range(1, len(hits)):
        k = int(periods[i])
        # -0 slices to the start, so a period of 0 counts the whole history.
        result[i] = int(hits[:i][-k:].sum())
    return np.array(result)
